import os
import json
import numpy as np
import torch
import tqdm
import hydra
import hydra.utils as hu
from transformers import set_seed
from torch.utils.data import DataLoader
from src.utils.collators import DataCollatorWithPaddingAndCuda
from src.models.biencoder import BiEncoder
import pdb


def submodular_diverse_select_no_log(demo_embeds: np.ndarray,
                                     test_embed: np.ndarray,
                                     k: int,
                                     lambd: float) -> list:
    """
    Greedily select up to ``k`` demonstration indices for one test embedding.

    Starting from ``V_S = 0.02 * I``, each step scores every not-yet-selected
    row ``x`` of ``demo_embeds`` as

        score(x) = (t^T V_S^{-1} x)^2 / (1 + x^T V_S^{-1} x)
                   + lambd * (1 + x^T V_S^{-1} x)

    where ``t`` is ``test_embed``. It picks the highest-scoring row (ties go
    to the lowest index) and then adds ``x x^T`` to ``V_S``. No logging is
    performed.

    Args:
        demo_embeds: 2-D array of candidate embeddings, one row per
            demonstration (shape ``(n, d)``).
        test_embed: 1-D query embedding of length ``d``.
        k: number of demonstrations to select. Values larger than ``n`` are
            clamped to ``n``; values ``<= 0`` select nothing.
        lambd: weight of the ``1 + x^T V_S^{-1} x`` term in the score.

    Returns:
        List of selected row indices (Python ints) in selection order.
    """
    # Work in float64 throughout so the running V_S and its inverse stay accurate.
    demos = np.asarray(demo_embeds, dtype=np.float64)
    query = np.asarray(test_embed, dtype=np.float64)
    n, d = demos.shape
    V_S = 0.02 * np.eye(d)
    available = np.ones(n, dtype=bool)
    selected = []

    # Clamp k: there are only n candidates. Previously the loop would try to
    # remove a None index and crash with KeyError.
    for _ in range(min(k, n)):
        invV_S = np.linalg.inv(V_S)
        cand = np.flatnonzero(available)  # ascending, so argmax ties -> lowest index
        X = demos[cand]
        # t^T V_S^{-1} x for every candidate; (t @ invV_S) is shared by all rows.
        proj = X @ (query @ invV_S)
        # 1 + x^T V_S^{-1} x for every candidate (row-wise quadratic form).
        denom = 1.0 + np.einsum("ij,ij->i", X @ invV_S, X)
        scores = proj ** 2 / denom + lambd * denom
        # NaN never wins; if every score is NaN, fall back to the first candidate.
        scores = np.where(np.isnan(scores), -np.inf, scores)
        best_idx = int(cand[np.argmax(scores)])
        selected.append(best_idx)
        available[best_idx] = False
        V_S += np.outer(demos[best_idx], demos[best_idx])
    return selected


def _load_model(cfg, device):
    """Build the BiEncoder described by ``cfg``, move it to ``device`` and set eval mode."""
    model_config = hu.instantiate(cfg.model_config)
    # 0 (the config's "no checkpoint" sentinel) or null -> build from config alone.
    if cfg.pretrained_model_path not in (0, None):
        print(f"Loading model from: {cfg.pretrained_model_path}")
        model = BiEncoder.from_pretrained(cfg.pretrained_model_path, config=model_config)
    else:
        model = BiEncoder(model_config)
    model.to(device)
    model.eval()
    return model


def _encode_loader(model, loader, desc, with_metadata):
    """Encode every batch of ``loader`` with ``model.encode``.

    Returns ``(embeds, metadata)``:
    - ``embeds`` stacks the per-batch outputs row-wise. It is an empty
      ``(0, 0)`` array when the loader yields no batches.
    - ``metadata`` concatenates ``batch["metadata"].data`` when
      ``with_metadata`` is true; otherwise it stays empty.
    """
    chunks = []
    metadata = []
    for batch in tqdm.tqdm(loader, desc=desc):
        with torch.no_grad():
            embeds = model.encode(batch["input_ids"], batch["attention_mask"])
        chunks.append(embeds.cpu().numpy())
        if with_metadata:
            # Metadata is required here. The old `.get("metadata", []).data`
            # fallback could never work, because a list has no `.data`.
            metadata.extend(batch["metadata"].data)
    if not chunks:
        return np.empty((0, 0), dtype=np.float32), metadata
    return np.vstack(chunks), metadata


@hydra.main(config_path="configs", config_name="submodular_retriever")
def main(cfg):
    """Select ``cfg.num_ice`` demonstrations per query and write them as JSON.

    Steps:
    1. Encode the index dataset (``cfg.index_reader``) and the query dataset
       (``cfg.dataset_reader``) with the bi-encoder.
    2. Run ``submodular_diverse_select_no_log`` for each query. Stop after
       ``cfg.run_for_n_samples`` queries when that value is truthy.
    3. Write one entry per query to ``cfg.output_file``. Each entry is a copy
       of the original dataset entry with ``ctxs`` (selected index rows) and
       ``ctxs_candidates`` added.

    Raises:
        ValueError: if the index dataset yields no rows.
    """
    set_seed(cfg.seed if hasattr(cfg, 'seed') else 42)

    # initialize bi-encoder model
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = _load_model(cfg, device)

    # encode the index; its metadata was never used downstream, so skip it
    index_reader = hu.instantiate(cfg.index_reader)
    collator = DataCollatorWithPaddingAndCuda(tokenizer=index_reader.tokenizer, device=device)
    index_loader = DataLoader(index_reader,
                              batch_size=cfg.batch_size,
                              collate_fn=collator)
    index_embeds, _ = _encode_loader(model, index_loader, "Encoding index passages",
                                     with_metadata=False)
    if index_embeds.shape[0] == 0:
        raise ValueError("Index dataset is empty; nothing to select demonstrations from.")

    # encode the queries together with their metadata
    query_reader = hu.instantiate(cfg.dataset_reader)
    query_loader = DataLoader(query_reader,
                              batch_size=cfg.batch_size,
                              collate_fn=collator)
    query_embeds, query_metadata = _encode_loader(model, query_loader, "Encoding queries",
                                                  with_metadata=True)

    # run submodular selection for each query
    results = []
    for idx, (test_embed, meta) in tqdm.tqdm(enumerate(zip(query_embeds, query_metadata)), total=len(query_embeds)):
        if cfg.run_for_n_samples and idx >= cfg.run_for_n_samples:
            break
        selected_idxs = submodular_diverse_select_no_log(
            demo_embeds=index_embeds,
            test_embed=test_embed,
            k=cfg.num_ice,
            lambd=cfg.lambd,
        )

        # fetch and copy the original entry (shallow copy)
        orig_entry = query_reader.dataset_wrapper[meta["id"]].copy()
        orig_entry["ctxs"] = selected_idxs
        # also include ctxs_candidates identical to individual selections
        orig_entry["ctxs_candidates"] = [[i] for i in selected_idxs]
        results.append(orig_entry)

    # write output; os.makedirs("") raises FileNotFoundError for a bare filename
    out_dir = os.path.dirname(cfg.output_file)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(cfg.output_file, "w", encoding="utf-8") as fout:
        json.dump(results, fout, indent=2)



if __name__ == "__main__":
    # Script entry point: the @hydra.main decorator builds cfg and passes it in.
    main()
